#Import statements

import pandas as pd
from datetime import datetime
from sklearn import metrics
import textdistance
import itertools
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import random
import time

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc
from functions import *


#######################################################
#Preprocessing
#######################################################

#Seeding both the stdlib and NumPy global random number generators
random.seed(50)
np.random.seed(50)

#Wall-clock start; the elapsed time is printed at the very end of the script
start_time = time.time()

#Loading in hand-labeled pairs for training set and a few extra org names for generating negative pairs
hand_labeled = pd.read_csv('hand_labeled_matches.csv')

#Using a context manager so the pickle file handle is closed as soon as it is read.
#NOTE: unpickling can execute arbitrary code, so "unrepresented.p" must come from a trusted source.
with open("unrepresented.p", "rb") as unrepresented_file:
    unrepresented = pkl.load(unrepresented_file)

print("starting...")
#Shuffling Hand-Labeled and Extra Org Names to Generate 71,043 String Pair Training Set
def many_to_one_shuffle(amicus_orgs, bonica_orgs, additional_amicus):
    '''
    Builds the labeled training table by crossing amicus names with bonica names.

    The hand-labeled pairs (amicus_orgs[i], bonica_orgs[i]) are the positives (label 1).
    Every other combination of a distinct amicus name (including the additional ones)
    with a bonica name is added exactly once as a negative (label 0).

    The input sequences are never modified.

    @input amicus_orgs: list of the amicus organization names from the hand-labeled data
        (357 names in the original data)
    @input bonica_orgs: list of the bonica organization names from the hand-labeled data,
        aligned index-by-index with amicus_orgs
    @input additional_amicus: additional organization names used to generate more
        negative pairs (11 names in the original data)

    @output the table returned by precompute_distances for every labeled string pair,
        plus a 'label' column; of size (71043, 8) for the original data
    @raises ValueError if amicus_orgs and bonica_orgs have different lengths
    '''
    amicus_list, bonica_list, labels = _build_labeled_pairs(amicus_orgs, bonica_orgs, additional_amicus)

    #computing the string metrics
    output = precompute_distances(amicus_list, bonica_list, dump_duplicates=False, combination_type='standard')
    output['label'] = labels

    return output


def _build_labeled_pairs(amicus_orgs, bonica_orgs, additional_amicus):
    '''
    Returns (amicus_list, bonica_list, labels): the positives followed by all new negative pairs.

    Kept separate from the distance computation so the pairing logic has no dependency
    on functions.py.
    '''
    if len(amicus_orgs) != len(bonica_orgs):
        raise ValueError(
            'amicus_orgs and bonica_orgs must be the same length, got %d and %d'
            % (len(amicus_orgs), len(bonica_orgs))
        )

    #Working on copies: appending to the caller's lists would both corrupt them and
    #make the loop below iterate over a list that grows while it is being traversed
    amicus_list = list(amicus_orgs)
    bonica_list = list(bonica_orgs)
    labels = [1] * len(bonica_list)
    seen_pairs = set(zip(amicus_list, bonica_list))

    #dict.fromkeys de-duplicates while keeping first-seen order (guaranteed from Python 3.7),
    #so the row order no longer depends on string hash randomization between runs
    unique_amicus = list(dict.fromkeys(amicus_orgs)) + list(additional_amicus)

    #Snapshot of the original bonica names; bonica_list itself grows inside the loop
    original_bonica = tuple(bonica_orgs)

    for amicus in unique_amicus:
        for bonica in original_bonica:
            pair = (amicus, bonica)
            if pair not in seen_pairs:
                seen_pairs.add(pair)
                amicus_list.append(amicus)
                bonica_list.append(bonica)
                labels.append(0)

    return amicus_list, bonica_list, labels
            

#Base Training is the starting training set: 71,043 labeled string pairs, 357 of which are matches.
#The positives come from the hand-labeled CSV; the extra names only contribute negative pairs.
#random.seed(30)
#np.random.seed(30)
base_training = many_to_one_shuffle(
    amicus_orgs=list(hand_labeled['amicus']),
    bonica_orgs=list(hand_labeled['bonica']),
    additional_amicus=unrepresented,
)

#Optional cache of the training set (re-loaded by a commented-out read_csv further down)
#base_training.to_csv('full_train_set.csv')

#######################################################
#Initiating Data Structures Before First Iteration
#######################################################

#we did 10 human-in-the-loop iterations for this task
num_iterations = 10

#total_train_set will hold the labeled string pairs accumulated across iterations.
#It starts out as an empty dataframe.
total_train_set = pd.DataFrame({})

#feature_importances is an 11x5 array which will store the feature importances for each iteration.
#Row 0 belongs to the base model, rows 1..10 to the HITL iterations; each column is a string metric.
feature_importances = np.ones(shape=(num_iterations + 1, 5))

#amicus_orgs is a list of 13,939 organization names
amicus_orgs = list(pd.read_csv('amicus_org_names.csv')['amicus'])

#bonica_orgs is a list of 1,332,470 organization names
bonica_orgs = list(pd.read_csv('bonica_org_names.csv')['x'])

#base_training = pd.read_csv('full_train_set.csv')

#num_positives and num_negatives track how many positive / negative data points each model
#was given. The first entry is the count in the base training set; one more entry is
#appended per HITL iteration.
num_positives = [len(base_training[base_training['label'] == 1])]
num_negatives = [len(base_training[base_training['label'] == 0])]

#random.seed(30)
#np.random.seed(30)

#creating the training sets for our baseline model trained on just the base training set, before we perform
#any human-in-the-loop iterations, and then fitting that model.
#train comes from functions.py; it hands back the fitted model and its feature importances.
X_train = base_training[['cosine', 'jaccard', 'levenshtein', 'lcsstr', 'overlap']]
y_train = base_training['label']
base_model, feat_imports = train(X_train, y_train)

#saving the feature importances (row 0 is the baseline model)
feature_importances[0] = feat_imports

#saving the model; the context manager guarantees the file is flushed and closed
#before the first HITL iteration reads it back
with open("iter0_model.p", "wb") as model_file:
    pkl.dump(base_model, model_file)

total_train_set = base_training

#######################################################
#The HITL Iterations
#######################################################

#iterating through 10 iterations
for i in range(num_iterations):
    print('iteration', i)
    #setting seed so that its different for each iterations
    #random.seed(30+i)
    #np.random.seed(30+i)

    #Loading in model from previous iteration (file handle is closed right after reading)
    with open("iter"+str(i)+"_model.p", "rb") as model_file:
        this_iter_model = pkl.load(model_file)

    #the amicus_bonica_iteration function from functions.py finds a sample of 27,878,000 and uses
    #the model from the previous iteration to rank the pairs in order from most likely to be a match
    #to the least likely to be a match. It then asks the user about the top 100 most likely string pairs
    #and whether those pairs are actually a match. It returns a labeled table of these 100 string pairs.
    #
    #The interactive branch is switched off by hand: the else branch performs the same sampling and
    #scoring steps but reads the already-labeled pairs for this iteration from disk instead of prompting.
    if False:
        this_iter = amicus_bonica_iteration(2000, 100, this_iter_model, amicus_orgs, bonica_orgs)
    else:
        bonica_sample = np.random.choice(bonica_orgs, size=2000, replace=False)
        pairs_df = precompute_distances(amicus_orgs, bonica_sample)
        pairs_df = get_predictions(this_iter_model, pairs_df, 'basic_score')
        #best_matches is not used below; the labels come from the CSV read next
        best_matches = pairs_df.head(100)
        print('Human in the Loop Step:')
        print('Reading in pre-labeled pairs from this iteration...')
        this_iter = pd.read_csv('iter'+str(i+1)+'_pairs.csv')

    #NOTE(review): num_neg is derived from a formula rather than by counting label == 0 rows
    #in this_iter - confirm it matches how the iteration's pairs file is constructed
    num_pos = this_iter.loc[this_iter['label'] == 1].shape[0]
    num_neg = (num_pos**2-num_pos) + (100- num_pos)
    num_positives.append(num_pos)
    num_negatives.append(num_neg)

    #saving this iteration's labeled string pairs as a CSV
    this_iter.to_csv('iter'+str(i+1)+'_pairs.csv', index=False)

    #rebuilding the full training set: base training set plus the pairs of every iteration so far
    total_train_set = base_training
    for j in range(1,i+2):
        shuffled_iter = shuffle(pd.read_csv('iter'+str(j)+'_pairs.csv'))
        total_train_set = pd.concat([total_train_set, shuffled_iter])

    #building and training this iterations model
    #random.seed(30+i)
    #np.random.seed(30+i)
    X_train = total_train_set[['cosine', 'jaccard', 'levenshtein', 'lcsstr', 'overlap']]
    y_train = total_train_set['label']

    model, feat_imports = train(X_train, y_train)

    #random.seed(30+i)
    #np.random.seed(30+i)

    #Saving this iterations model and recording feature importances.
    #The context manager makes sure the pickle is fully written before the next iteration loads it.
    with open("iter" + str(i+1) +"_model.p", "wb") as model_file:
        pkl.dump(model, model_file)
    feature_importances[i+1] = feat_imports

#Writing one row per model (baseline + each HITL iteration) to a CSV: the five feature
#importances, the iteration number, the positive/negative counts and the task name
feats = pd.DataFrame(
    feature_importances,
    columns=['cosine', 'jaccard', 'levenshtein', 'lcsstr', 'overlap'],
)
feats = feats.assign(**{
    'Iteration Number': np.arange(num_iterations + 1),
    'num_matches': num_positives[:num_iterations + 1],
    'num_negatives': num_negatives[:num_iterations + 1],
    'Task': ['Interest Group Ideology'] * (num_iterations + 1),
})
feats.to_csv('feature_importances_by_iteration.csv', index=False)

#######################################################
#Calculating AUC Values
#######################################################

#loading a hand-labeled test set of 4000 string pairs
test_set = pd.read_csv('evaluation_set.csv')

#aucs is a list which will track the auc values of each iteration (index 0 is the baseline model)
aucs = []
for i in range(num_iterations+1):
    #loading this iteration's model (file handle is closed right after reading)
    with open("iter"+str(i)+"_model.p", "rb") as model_file:
        model = pkl.load(model_file)

    #the get_predictions function from functions.py applies the passed in model to the passed in test set,
    #and returns a table with the string pairs and their corresponding model scores under the passed-in column name
    pred_table = get_predictions(model, test_set, 'model score')

    #calculating the AUC for this iteration and appending it to the aucs list
    fpr, tpr, thresholds = metrics.roc_curve(pred_table['label'], pred_table['model score'], drop_intermediate=False)
    this_auc = auc(fpr, tpr)
    aucs.append(this_auc)

#Saving AUC values to a csv
#NOTE(review): this drop only succeeds if get_predictions added 'model score' to test_set itself - confirm
test_set = test_set.drop('model score', axis=1)
#The iteration column is sized from num_iterations so it always matches len(aucs)
pd.DataFrame({'Iteration #': range(num_iterations+1),'AUC': aucs}).to_csv('AUC_table.csv',index=False)


#######################################################
#Generating Final_Results.csv table
#######################################################

#finding index corresponding to the best AUC performance (first one on ties)
best_auc_index = np.argmax(aucs)

#loading the model corresponding to that iteration
with open("iter"+str(best_auc_index)+"_model.p", "rb") as model_file:
    best_model = pkl.load(model_file)

#loading the basic model from before any HITL iterations
with open("iter0_model.p", "rb") as model_file:
    basic_model = pkl.load(model_file)

#finding best model and basic model predictions, saving as CSV
final_results = get_predictions(basic_model, test_set, 'basic_score')
final_results = get_predictions(best_model, final_results, 'hitl_score')
final_results.to_csv('Final_Results.csv', index=False)

#reporting the elapsed wall-clock time in seconds since start_time was taken
end_time = time.time()
print('')
print('total run time:')
print(end_time-start_time)